import random

import keras
from keras import backend as K
from keras.models import Model
from keras.layers import Input, Embedding, LSTM, Dense, Reshape, Bidirectional, Lambda, Flatten, Multiply
from keras.utils import plot_model, Sequence
from gensim.models import Word2Vec, KeyedVectors

from Cat2Vec.Attention import Attention
from Cat2Vec.cat2vecUtils import *
from utils import *
from Cat2Vec.sqllite_adapter import *
import tensorflow as tf


# Seed Python's global RNG so word masking and negative-category sampling
# (random.random / random.randint below) are reproducible between runs.
random.seed(a=0)

def pretrain_gensim_word2vec(text, model_name="word2vec.model", size=300, window=4,
                             min_count=1, workers=8, iterations=100):
    """Train a gensim Word2Vec model on ``text``, save it and return it.

    :param text: training corpus handed straight to ``Word2Vec`` (presumably a
        list of tokenised documents, i.e. lists of words — confirm at call site).
    :param model_name: path the trained model is written to via ``model.save``.
    :param size: word-vector dimensionality (gensim 3.x ``size`` keyword).
    :param window: context window size.
    :param min_count: ignore words with a lower total frequency.
    :param workers: number of worker threads gensim uses.
    :param iterations: training epochs over the corpus (gensim 3.x ``iter``).
    :return: the trained ``Word2Vec`` model.
    """
    # Keyword names follow the gensim 3.x API used throughout this file
    # (``wv.index2word``, ``get_keras_embedding``); gensim 4 renamed them.
    model = Word2Vec(text, size=size, window=window, min_count=min_count,
                     workers=workers, iter=iterations)
    model.save(model_name)
    return model

def vocab_dictionary_from_pretrained(model):
    """Map each word of a trained gensim model to its position in ``wv.index2word``.

    The position is the row of that word in the model's vector matrix, so the
    resulting ids line up with the embedding built by
    ``embedding_layer_from_pretrain``.

    NOTE(review): id 0 belongs to a real word here, yet the embedding layer is
    configured with ``mask_zero=True`` and the input pipeline pads with 0, so
    that word appears to be treated as padding — confirm this is acceptable.

    :param model: trained gensim (3.x) ``Word2Vec`` model.
    :return: dict word -> integer id.
    """
    return {word: position for position, word in enumerate(model.wv.index2word)}

def embedding_layer_from_pretrain(model):
    """Return a frozen Keras embedding layer initialised from ``model``'s word vectors.

    The layer comes from gensim's ``get_keras_embedding(train_embeddings=False)``
    and is then adjusted in place:
      * ``input_length`` is cleared so documents of any length are accepted,
      * it is named ``word_embedding`` (the same name ``create_model`` gives its
        trainable fallback embedding),
      * ``mask_zero`` is enabled because the input generators pad with id 0.

    :param model: trained gensim (3.x) ``Word2Vec`` model.
    :return: the configured embedding layer.
    """
    layer = model.wv.get_keras_embedding(train_embeddings=False)
    overrides = (
        ("input_length", None),
        ("name", "word_embedding"),
        ("mask_zero", True),
    )
    for attribute, value in overrides:
        setattr(layer, attribute, value)
    return layer

def create_vocab(document_texts):
    """Return the distinct tokens of all documents as a sorted list.

    ``flatten`` comes from a helper module; it is presumably expected to turn the
    per-document token lists into a single flat iterable — confirm.
    """
    return sorted(set(flatten(document_texts)))

def vocab_dictionary(vocab):
    """Map each word of ``vocab`` to its index, leaving out ``vocab[0]``.

    Index 0 is never handed out so it stays free for padding/masking (the
    generators pad with 0 and the trainable embedding uses ``mask_zero=True``).
    Note that ``vocab[0]`` itself is dropped rather than shifted; presumably it
    is expected to be a placeholder entry — confirm.

    :param vocab: indexable sequence of words.
    :return: dict word -> integer id in ``1 .. len(vocab) - 1``.
    """
    return {word: word_id for word_id, word in enumerate(vocab) if word_id > 0}

def category_dictionary(categories):
    """Give every distinct category label a dense id, ordered by sorted label.

    :param categories: iterable of (sortable, hashable) category labels.
    :return: dict label -> id in ``0 .. n_distinct - 1``.
    """
    ordered_labels = sorted(set(categories))
    return dict(zip(ordered_labels, range(len(ordered_labels))))

def generate_input_words_vector(document_word_list, vocab_dict, limit=0, mask_percentage=0.0):
    """Convert a tokenised document into a list of vocabulary ids.

    Words missing from ``vocab_dict`` are dropped. Each remaining word is
    replaced by 0 (the padding/mask id) with probability ``mask_percentage``,
    whether or not ``limit`` is set.

    :param document_word_list: iterable of tokens.
    :param vocab_dict: mapping token -> integer id.
    :param limit: 0 to return the full variable-length id list; otherwise the
        list is truncated or right-padded with 0 to exactly ``limit`` ids.
    :param mask_percentage: probability in [0, 1] of masking each known word.
    :return: list of integer ids.
    """
    known_words = [word for word in document_word_list if word in vocab_dict]
    # One draw per known word, even when mask_percentage is 0, so the global RNG
    # advances the same way as before for the default (limit == 0) path.
    ids = [vocab_dict[word] if random.random() >= mask_percentage else 0 for word in known_words]
    if limit == 0:
        return ids
    if len(ids) >= limit:
        return ids[:limit]
    return ids + [0] * (limit - len(ids))

def generate_input_category_vector(category, category_dict):
    """Return a one-element numpy array holding the id of ``category``.

    Raises ``KeyError`` when ``category`` is not a key of ``category_dict``.
    """
    category_id = category_dict[category]
    return np.array([category_id])

def create_model(vocab, categories, embedding_length, word_embedding_layer=None, sentiments=None, max_len=50):
    """Build and compile the Cat2Vec Keras model.

    Inputs:
      * ``max_len`` word ids per document (0 is the padding/mask id),
      * one category id per example.

    The document is encoded by a bidirectional LSTM followed by ``Attention``.
    A second category embedding is passed through a sigmoid Dense and then a
    tanh Dense. The result multiplies both the document encoding and the
    category embedding. Their squared Euclidean distance feeds a one-unit sigmoid
    ("category_output"), trained with binary cross-entropy on the
    category-matches-document labels produced by ``TrainingDataGenerator``.

    The second output depends on ``sentiments``:
      * ``None``: "document_output", a softmax over ``categories`` trained with
        categorical cross-entropy on one-hot category targets.
      * a mapping sentiment label -> id: "sentiment_output", a softmax over
        ``len(sentiments)`` classes trained with sparse categorical
        cross-entropy on the integer sentiment ids the generator yields.

    :param vocab: word vocabulary; only ``len(vocab)`` is used, as the input dim
        of the trainable embedding built when ``word_embedding_layer`` is None.
    :param categories: category collection; only ``len(categories)`` is used.
    :param embedding_length: size of the category embeddings and of the document
        encoding. The BiLSTM has ``embedding_length // 2`` units per direction,
        so an even value is needed for the shapes to agree.
    :param word_embedding_layer: optional prebuilt word embedding layer (e.g.
        from ``embedding_layer_from_pretrain``).
    :param sentiments: optional mapping sentiment label -> id, see above.
    :param max_len: number of word ids per document.
    :return: the compiled ``Model``.
    """
    input_words = Input(shape=(max_len,))
    if word_embedding_layer is None:
        word_embeddings = Embedding(len(vocab), embedding_length, name='word_embedding', input_length=None, trainable=True, mask_zero=True)
    else:
        word_embeddings = word_embedding_layer

    document_word_embeddings = word_embeddings(input_words)

    # Both directions are concatenated: 2 * (embedding_length // 2) features per step.
    forward_lstm = Bidirectional(LSTM(embedding_length // 2, dropout=0.2, return_sequences=True, input_shape=(None, embedding_length)))
    forward_embedding = forward_lstm(document_word_embeddings)
    # NOTE(review): assumes Attention pools (batch, steps, features) into
    # (batch, features) — confirm in Cat2Vec.Attention.
    forward_lstm_attention = Attention()(forward_embedding)

    # Second output head, created at the same point in the graph as before so the
    # non-sentiment model keeps its original layer order.
    if sentiments is None:
        second_output = Dense(len(categories), activation="softmax", name="document_output")(forward_lstm_attention)
    else:
        second_output = Dense(len(sentiments), activation="softmax", name="sentiment_output")(forward_lstm_attention)

    input_categories = Input(shape=(1,))
    category_embeddings = Embedding(len(categories), embedding_length, name='cat_embedding', trainable=True)
    document_category = category_embeddings(input_categories)

    category_importance_embeddings = Embedding(len(categories), embedding_length, name='gated_cat_embedding',
                                          trainable=True)  # embeddings_constraint=keras.constraints.MinMaxNorm(min_value=0.0, max_value=1.0, rate=1.0, axis=[0,1])

    importance_layer = Dense(embedding_length, activation="sigmoid", name="gate")
    category1 = importance_layer(category_importance_embeddings(input_categories))

    tanh = Dense(embedding_length, activation="tanh", name="tanh1")
    category_importance = tanh(category1)

    gated_document_category = Multiply(name="multiply_category")([category_importance, document_category])
    gated_document_embedding = Multiply(name="multiply_document")([category_importance, forward_lstm_attention])

    document_embedding = Reshape((embedding_length, 1), name="doc_reshape")(gated_document_embedding)
    document_category = Reshape((embedding_length, 1), name="category_reshape")(gated_document_category)

    def euclidean_squared_distance(vectors):
        x, y = vectors
        return K.sum(K.square(x - y), axis=1, keepdims=True)

    def euclidean_distance_output_shape(shapes):
        shape1, shape2 = shapes
        return shape1[0], 1

    distance = Lambda(euclidean_squared_distance, output_shape=euclidean_distance_output_shape)([document_embedding, document_category])
    # Flatten the per-example distance to shape (batch, 1).
    dot_product = Reshape((1,))(distance)
    dot_output = Dense(1, activation='sigmoid', name="category_output")(dot_product)

    if sentiments is not None:
        losses = {
            "category_output": "binary_crossentropy",
            "sentiment_output": "sparse_categorical_crossentropy",
        }
        metrics = {
            "category_output": [keras.metrics.BinaryAccuracy()],
            "sentiment_output": ["sparse_categorical_accuracy"],
        }
    else:
        losses = {
            "document_output": "categorical_crossentropy",
            "category_output": "binary_crossentropy",
        }
        metrics = [keras.metrics.BinaryAccuracy()]
    model = Model(inputs=[input_words, input_categories], outputs=[dot_output, second_output])

    model.compile(loss=losses, optimizer='adam', metrics=metrics)
    model.summary()
    # Drops dict entries from the private ``_layers`` list; presumably a
    # workaround for a Keras-version quirk when saving — confirm before removing.
    model._layers = [layer for layer in model._layers if not isinstance(layer, dict)]
    # plot_model(model, to_file='model.png')

    return model

class TrainingDataGenerator(Sequence):
    """Keras ``Sequence`` of category-matching training batches.

    Every document in a batch yields one example paired with its true category
    (match label 1). It is followed by ``negative_samples`` examples paired with
    a uniformly drawn random category (label 1 only if the draw hits the true
    category, else 0).

    Each item is ``[padded_word_ids, category_ids], [match_labels, second_target]``.
    The second target is the one-hot true category, or the document's integer
    sentiment id when ``sentiments`` is given.
    """

    def __init__(self, document_texts, vocab_dictionary, categories, category_dict, sentiments=None,
                 sentiments_dict=None, batch=256, mask_percentage=0.0, max_len=100, negative_samples=8):
        """
        :param document_texts: tokenised documents.
        :param vocab_dictionary: token -> word id mapping.
        :param categories: category label per document (aligned with document_texts).
        :param category_dict: category label -> id mapping.
        :param sentiments: optional sentiment label per document; needs sentiments_dict.
        :param sentiments_dict: sentiment label -> id mapping.
        :param batch: documents per batch; each contributes 1 + negative_samples examples.
        :param mask_percentage: word-masking probability (see generate_input_words_vector).
        :param max_len: word ids kept / padded per example.
        :param negative_samples: random-category examples generated per document.
        """
        # Masking is drawn once here, so the same words stay masked in every epoch.
        self.words = [generate_input_words_vector(text, vocab_dictionary, mask_percentage=mask_percentage) for text in document_texts]
        self.categories = [category_dict[category] for category in categories]
        self.number_categories = len(category_dict)
        self.indexes = list(range(len(self.words)))
        self.batch_size = batch
        self.max_len = max_len
        self.negative_samples = negative_samples
        self.sentiments = sentiments
        if sentiments is not None:
            self.sentiments = [sentiments_dict[sentiment] for sentiment in sentiments]

    def __len__(self):
        """Number of batches per epoch, including a final partial batch."""
        return math.ceil(len(self.words) / self.batch_size)

    def __getitem__(self, index):
        """Build batch ``index`` (the last batch may hold fewer documents)."""
        start = index * self.batch_size
        stop = min(start + self.batch_size, len(self.words))
        input_words = []
        input_random_category = []
        output_category = []
        output_true_category = []
        output_sentiments = []

        for idx in range(start, stop):
            true_category = self.categories[idx]
            words = self.words[idx][:self.max_len]
            document_output = [0] * self.number_categories
            document_output[true_category] = 1

            # Positive example: the document with its own category.
            input_words.append(words)
            input_random_category.append([true_category])
            output_true_category.append([1])
            output_category.append(document_output)
            if self.sentiments is not None:
                output_sentiments.append([self.sentiments[idx]])

            # Negative examples: same document, uniformly random category.
            for _ in range(self.negative_samples):
                random_category = random.randint(0, self.number_categories - 1)
                input_words.append(words)
                input_random_category.append([random_category])
                output_true_category.append([1] if random_category == true_category else [0])
                output_category.append(document_output)
                if self.sentiments is not None:
                    output_sentiments.append([self.sentiments[idx]])

        padded_words = tf.keras.preprocessing.sequence.pad_sequences(input_words,
                                                                      padding='post', maxlen=self.max_len)
        if self.sentiments is not None:
            return [padded_words, np.array(input_random_category)], [np.array(output_true_category), np.array(output_sentiments)]
        return [padded_words, np.array(input_random_category)], [np.array(output_true_category), np.array(output_category)]

def save_dictionary(save_dictionary, filename):
    """Write a mapping to ``filename`` as UTF-8 text, one ``key,value`` line per entry.

    Keys and values are written with ``str.format`` and no escaping, so a key
    containing a comma or newline cannot be read back unambiguously.

    :param save_dictionary: mapping to write (name kept for backward
        compatibility; it shadows this function inside the body).
    :param filename: destination path; an existing file is overwritten.
    """
    # ``with`` guarantees the file is closed even if formatting/writing raises.
    with open(filename, "w", encoding="utf8") as out:
        for key in save_dictionary:
            out.write("{},{}\n".format(key, save_dictionary[key]))

def train_pretrained(version="v5"):
    """Train Cat2Vec on CategoricalNews with freshly trained word2vec embeddings.

    Reads the articles from ``../Datasets/CategoricalNews.db``, preprocesses the
    text, drops words using ``filter_texts_by_count`` with threshold 4, and trains
    (and saves) a Word2Vec model. It then builds the model on the frozen
    pretrained embeddings (``max_len=50``) and fits it for 15 epochs. It writes
    ``vocab_<version>.txt`` and ``categories_<version>.txt`` and saves the model
    as ``cat2vec_<version>``.

    NOTE(review): article fields are read positionally (text = last column,
    category = second to last) — confirm against ``get_articles``.
    """
    db = create_connection("../Datasets/CategoricalNews.db")
    try:
        articles = get_articles(db)
    finally:
        db.close()
    print("Load Documents")
    document_texts = [preprocess_text(article[-1]) for article in articles]
    word_count = count_word_occurences(document_texts)
    document_texts = filter_texts_by_count(document_texts, word_count, 4)
    categories = [article[-2] for article in articles]
    category_dict = category_dictionary(categories)
    print("Process Documents")
    # The trained model is returned directly; no need to reload it from disk.
    pretrained = pretrain_gensim_word2vec(document_texts)

    vocab_dict = vocab_dictionary_from_pretrained(pretrained)
    print("Train Word2Vec Model")
    model = create_model(vocab_dict, category_dict, 300, word_embedding_layer=embedding_layer_from_pretrain(pretrained), max_len=50)

    # Vectorise the corpus once; fetching the first batch is a quick sanity check
    # that fails fast before training starts.
    generator = TrainingDataGenerator(document_texts, vocab_dict, categories, category_dict, max_len=50)
    generator[0]

    model.fit_generator(generator, epochs=15)

    save_dictionary(vocab_dict, "vocab_{}.txt".format(version))
    save_dictionary(category_dict, "categories_{}.txt".format(version))
    save_model(model, "cat2vec_{}".format(version))

def train_sentiment(version="w100v1"):
    """Train Cat2Vec on CategoricalNews with sentiment as the second target.

    Like ``train_pretrained``, but it differs in four ways:
      * It filters words with threshold 5.
      * It reuses the Word2Vec model previously saved as ``word2vec.model``
        instead of retraining it.
      * It passes each article's sentiment ("negative" -> 0, "positive" -> 1) to
        the model and generator.
      * It trains for 5 epochs with ``max_len=100``.

    It writes ``vocab_<version>.txt`` and ``categories_<version>.txt`` and saves
    the model as ``cat2vec_<version>``.

    NOTE(review): the reused word2vec model may have been trained on differently
    filtered text (train_pretrained uses threshold 4) — confirm that is intended.
    Sentiment is read positionally (third column from the end) — confirm
    against ``get_articles``.
    """
    db = create_connection("../Datasets/CategoricalNews.db")
    try:
        articles = get_articles(db)
    finally:
        db.close()
    print("Load Documents")
    document_texts = [preprocess_text(article[-1]) for article in articles]
    word_count = count_word_occurences(document_texts)
    document_texts = filter_texts_by_count(document_texts, word_count, 5)
    sentiments = [article[-3] for article in articles]
    sentiments_dict = {"negative": 0, "positive": 1}
    categories = [article[-2] for article in articles]
    category_dict = category_dictionary(categories)
    print("Process Documents")
    pretrained = Word2Vec.load("word2vec.model")
    vocab_dict = vocab_dictionary_from_pretrained(pretrained)
    print("Train Word2Vec Model")
    model = create_model(vocab_dict, category_dict, 300, word_embedding_layer=embedding_layer_from_pretrain(pretrained), sentiments=sentiments_dict, max_len=100)

    # Vectorise the corpus once; fetching the first batch is a quick sanity check.
    generator = TrainingDataGenerator(document_texts, vocab_dict, categories, category_dict, sentiments=sentiments, sentiments_dict=sentiments_dict, max_len=100)
    generator[0]

    # Let Keras use len(generator) batches per epoch (every document once).
    # steps_per_epoch=len(articles) counted documents as batches, making each
    # "epoch" revisit the data about batch-size (256) times.
    model.fit_generator(generator, epochs=5)

    save_dictionary(vocab_dict, "vocab_{}.txt".format(version))
    save_dictionary(category_dict, "categories_{}.txt".format(version))
    save_model(model, "cat2vec_{}".format(version))


if __name__ == '__main__':
    # Script entry point: train with freshly pretrained embeddings, tagging the
    # saved vocabulary/category/model files with version "w50v2".
    train_pretrained(version="w50v2")
